#!/usr/bin/env python
# encoding: utf-8
'''
@author: wangjianrong
@software: pycharm
@file: vis_taget_labels.py
@time: 2020/9/27 17:24
@desc: For each image, show it with the pixels whose label value equals target_label highlighted.
'''
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np

# Dataset root; every derived folder below keeps a trailing '/' because the
# paths are later combined by plain string concatenation.
root_folder = '/workspace_wjr/shm/dataset/cityscapes/ori_dataset/'
train_folder, val_folder, test_folder = (
    f'{root_folder}{split}/' for split in ('train', 'val', 'test')
)

# Split to visualise: images live in <split>/images/, labels in <split>/labels/.
phase = 'train'
vis_folder = f'{root_folder}{phase}/'
image_folder = f'{vis_folder}images/'
label_folder = f'{vis_folder}labels/'

# Label pixel value whose pixels are highlighted in each image.
target_label = 7
def overlay_target(img_rgb, label, target, color=(0, 0, 255)):
    """Return a copy of ``img_rgb`` with every pixel whose label equals ``target`` set to ``color``.

    Args:
        img_rgb: H x W x C image array. It is not modified.
        label: per-pixel label values with the same height/width as ``img_rgb``.
        target: label value to highlight.
        color: value written into matching pixels. Images are converted to RGB
            before this is called, so the default (0, 0, 255) displays as blue.

    Returns:
        The highlighted copy of ``img_rgb``.

    Raises:
        ValueError: if the label's shape differs from the image's height/width
            (e.g. a multi-channel label or a mismatched image/label pair).
    """
    mask = np.asarray(label) == target
    if mask.shape != img_rgb.shape[:2]:
        raise ValueError(
            f'label shape {mask.shape} does not match image shape {img_rgb.shape[:2]}'
        )
    overlay = img_rgb.copy()
    overlay[mask] = color
    return overlay


def visualize_target_labels(image_dir=None, label_dir=None, target=None,
                            pause_seconds=5, color=(0, 0, 255)):
    """Display each image in ``image_dir`` whose label map contains ``target``.

    The label for ``<name>.<ext>`` is read from ``<label_dir>/<name>.png``.
    The number of matching pixels is printed for every image. Images with no
    match are skipped. The rest are shown with matching pixels painted
    ``color`` for ``pause_seconds`` seconds, then the window is closed and the
    next image follows. Unreadable images/labels and size mismatches are
    reported and skipped instead of crashing.

    Args:
        image_dir: folder of images; defaults to the module-level ``image_folder``.
        label_dir: folder of label PNGs; defaults to ``label_folder``.
        target: label value to highlight; defaults to ``target_label``.
        pause_seconds: how long each highlighted image stays on screen.
        color: fill value for matching pixels (RGB order).
    """
    image_dir = image_folder if image_dir is None else image_dir
    label_dir = label_folder if label_dir is None else label_dir
    target = target_label if target is None else target

    for img_name in sorted(os.listdir(image_dir)):
        img_path = os.path.join(image_dir, img_name)
        label_path = os.path.join(label_dir, os.path.splitext(img_name)[0] + '.png')

        # IMREAD_UNCHANGED keeps the stored label values as-is.
        label = cv2.imread(label_path, cv2.IMREAD_UNCHANGED)
        if label is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            print(f'{img_name}: cannot read label {label_path}, skipped')
            continue

        count = int(np.count_nonzero(label == target))
        print(f'{img_name}: {count}')
        if count == 0:
            continue

        # IMREAD_COLOR always yields 3-channel BGR, so BGR2RGB is always valid
        # (IMREAD_UNCHANGED could return grayscale and make cvtColor fail).
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            print(f'{img_name}: cannot read image {img_path}, skipped')
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        try:
            overlay = overlay_target(img, label, target, color)
        except ValueError as err:
            print(f'{img_name}: {err}, skipped')
            continue

        plt.imshow(overlay)
        plt.title(img_name)
        # plt.show() would block until the window is closed; pause() draws the
        # figure, keeps it on screen for pause_seconds, and then returns.
        plt.pause(pause_seconds)
        plt.close()


if __name__ == '__main__':
    visualize_target_labels()




